{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyPdCDD5x0HhXnMYznmSeNOD",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/generic_tools.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip --quiet install neo4j langchain-core langchain-community"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "WnnZVyJbdtit",
        "outputId": "759e4973-1d5b-485a-9b9b-b8a3bafa5fc1"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m41.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m72.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.5/49.5 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from langchain_community.graphs import Neo4jGraph\n",
        "\n",
        "os.environ[\"NEO4J_URI\"] = \"neo4j+s://demo.neo4jlabs.com\"\n",
        "os.environ[\"NEO4J_USERNAME\"] = \"recommendations\"\n",
        "os.environ[\"NEO4J_PASSWORD\"] = \"recommendations\"\n",
        "os.environ[\"NEO4J_DATABASE\"] = \"recommendations\"\n",
        "\n",
        "\n",
        "graph = Neo4jGraph(refresh_schema=False)"
      ],
      "metadata": {
        "id": "jx9uj95BeRnU"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain_core.tools import StructuredTool"
      ],
      "metadata": {
        "id": "1GrKSrN1gsX4"
      },
      "execution_count": 40,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 59,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XKTnJ54jdX4z",
        "outputId": "9b7d7bb8-93f3-4e6c-cc0c-9a0b37d9c8b7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cypher: MATCH (n:Movie) WHERE n.year >= $min_year AND n.year <= $max_year RETURN count(n) AS movie_count\n",
            "Params: {'min_year': 2000, 'max_year': 2020, 'min_rating': 7.5}\n"
          ]
        }
      ],
      "source": [
        "from typing import Any, Callable, Dict, List, Optional, Union\n",
        "from functools import wraps\n",
        "import inspect\n",
        "\n",
        "def create_filter_function(\n",
        "    node_label: str,\n",
        "    properties: Dict[str, type],\n",
        "    count_field: str = \"count\",\n",
        "    grouping_allowed: bool = True\n",
        ") -> Callable:\n",
        "    \"\"\"\n",
        "    Dynamically creates a filter function based on node properties.\n",
        "\n",
        "    Args:\n",
        "        node_label: The Neo4j node label\n",
        "        properties: Dictionary of property names and their types\n",
        "        count_field: Name of the count field in the result\n",
        "        grouping_allowed: Whether grouping by properties is allowed\n",
        "    \"\"\"\n",
        "\n",
        "    def generate_type_hints() -> Dict[str, Any]:\n",
        "        \"\"\"Generate type hints for the function parameters\"\"\"\n",
        "        hints = {}\n",
        "        for prop_name, prop_type in properties.items():\n",
        "            hints[f\"min_{prop_name}\"] = Optional[prop_type]\n",
        "            hints[f\"max_{prop_name}\"] = Optional[prop_type]\n",
        "        if grouping_allowed:\n",
        "            hints[\"grouping_key\"] = Optional[str]\n",
        "        return hints\n",
        "\n",
        "    def generate_parameters() -> List[inspect.Parameter]:\n",
        "        \"\"\"Generate function parameters\"\"\"\n",
        "        params = []\n",
        "        for prop_name, prop_type in properties.items():\n",
        "            if prop_type in (int, float):  # Only numeric types get min/max filters\n",
        "                params.extend([\n",
        "                    inspect.Parameter(\n",
        "                        f\"min_{prop_name}\",\n",
        "                        inspect.Parameter.POSITIONAL_OR_KEYWORD,\n",
        "                        annotation=Optional[prop_type],\n",
        "                        default=None\n",
        "                    ),\n",
        "                    inspect.Parameter(\n",
        "                        f\"max_{prop_name}\",\n",
        "                        inspect.Parameter.POSITIONAL_OR_KEYWORD,\n",
        "                        annotation=Optional[prop_type],\n",
        "                        default=None\n",
        "                    )\n",
        "                ])\n",
        "        if grouping_allowed:\n",
        "            params.append(\n",
        "                inspect.Parameter(\n",
        "                    \"grouping_key\",\n",
        "                    inspect.Parameter.POSITIONAL_OR_KEYWORD,\n",
        "                    annotation=Optional[str],\n",
        "                    default=None\n",
        "                )\n",
        "            )\n",
        "        return params\n",
        "\n",
        "    def create_filter_conditions(kwargs: Dict[str, Any]) -> List[tuple]:\n",
        "        \"\"\"Create filter conditions based on the provided arguments\"\"\"\n",
        "        filters = []\n",
        "        for prop_name, prop_type in properties.items():\n",
        "            if prop_type in (int, float):\n",
        "                min_val = kwargs.get(f\"min_{prop_name}\")\n",
        "                max_val = kwargs.get(f\"max_{prop_name}\")\n",
        "                if min_val is not None:\n",
        "                    filters.append((f\"n.{prop_name} >= $min_{prop_name}\", min_val))\n",
        "                if max_val is not None:\n",
        "                    filters.append((f\"n.{prop_name} <= $max_{prop_name}\", max_val))\n",
        "        return filters\n",
        "\n",
        "    def filter_function(*args, **kwargs) -> List[Dict]:\n",
        "        \"\"\"The dynamically generated filter function\"\"\"\n",
        "        filters = create_filter_conditions(kwargs)\n",
        "\n",
        "        # Create parameters dictionary for Neo4j query\n",
        "        params = {}\n",
        "        for param_name, value in kwargs.items():\n",
        "            if value is not None and param_name != \"grouping_key\":\n",
        "                # Keep the full parameter name (min_year, max_year, etc.)\n",
        "                params[param_name] = value\n",
        "\n",
        "        # Build Cypher query\n",
        "        where_clause = \" AND \".join(condition for condition, _ in filters)\n",
        "\n",
        "        cypher_statement = f\"MATCH (n:{node_label}) \"\n",
        "        if where_clause:\n",
        "            cypher_statement += f\"WHERE {where_clause} \"\n",
        "\n",
        "        grouping_key = kwargs.get(\"grouping_key\")\n",
        "        return_clause = (\n",
        "            f\"n.`{grouping_key}`, count(n) AS {count_field}\"\n",
        "            if grouping_key\n",
        "            else f\"count(n) AS {count_field}\"\n",
        "        )\n",
        "\n",
        "        cypher_statement += f\"RETURN {return_clause}\"\n",
        "\n",
        "        if grouping_key:\n",
        "            cypher_statement += f\" ORDER BY n.`{grouping_key}`\"\n",
        "\n",
        "        # For debugging\n",
        "        print(f\"Cypher: {cypher_statement}\")\n",
        "        print(f\"Params: {params}\")\n",
        "\n",
        "        # Execute query\n",
        "        return graph.query(cypher_statement, params=params)\n",
        "\n",
        "    # Create the function with the correct signature\n",
        "    sig = inspect.Signature(parameters=generate_parameters())\n",
        "    filter_function.__signature__ = sig\n",
        "    filter_function.__annotations__ = generate_type_hints()\n",
        "\n",
        "    return filter_function\n",
        "\n",
        "# Example usage:\n",
        "# Define node properties\n",
        "movie_properties = {\n",
        "    \"year\": int,\n",
        "    \"imdbRating\": float,\n",
        "    \"title\": str,\n",
        "}\n",
        "\n",
        "# Create the filter function\n",
        "movie_count = create_filter_function(\n",
        "    node_label=\"Movie\",\n",
        "    properties=movie_properties,\n",
        "    count_field=\"movie_count\"\n",
        ")\n",
        "\n",
        "# Use the generated function\n",
        "results = movie_count(\n",
        "    min_year=2000,\n",
        "    max_year=2020,\n",
        "    min_rating=7.5,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "results"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "r-g_28Ikd-nI",
        "outputId": "5b42f809-5eb7-42ab-9dd8-237094c9c0fe"
      },
      "execution_count": 60,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'movie_count': 3898}]"
            ]
          },
          "metadata": {},
          "execution_count": 60
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "bidC3HYSivnw"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}